#!/usr/bin/env python3

import os, itertools, urllib.request, urllib.parse, urllib.error, urllib.request, urllib.error, urllib.parse, sys


def get_ch_answer(query):
    """Send *query* to the ClickHouse HTTP endpoint and return the reply as text.

    The endpoint is read from the ``CLICKHOUSE_URL`` environment variable.
    When it is unset, ``http://localhost:<port>`` is used, where the port
    comes from ``CLICKHOUSE_PORT_HTTP`` and defaults to ``8123``.

    Args:
        query: Query text; it is encoded and sent as the request body.

    Returns:
        The decoded response body, exactly as received (not stripped).

    Raises:
        Whatever ``urllib.request.urlopen`` raises for an unreachable server
        or an error status (e.g. ``urllib.error.URLError``).
    """
    url = os.environ.get(
        "CLICKHOUSE_URL",
        "http://localhost:" + os.environ.get("CLICKHOUSE_PORT_HTTP", "8123"),
    )
    # The response object is a context manager; closing it promptly matters
    # because this function is called once per generated query.
    with urllib.request.urlopen(url, data=query.encode()) as response:
        return response.read().decode()


def check_answers(query, answer):
    """Run *query* on the server and abort the script if the reply differs.

    The fetched reply and *answer* are compared after stripping leading and
    trailing whitespace from both.

    Args:
        query: Query text passed to ``get_ch_answer``.
        answer: Expected reply text.

    Raises:
        SystemExit: with status ``-1`` on a mismatch, after printing the
            query, the expected answer and the fetched answer.
    """
    ch_answer = get_ch_answer(query)
    if ch_answer.strip() == answer.strip():
        return

    print("FAIL on query:", query)
    print("Expected answer:", answer)
    print("Fetched answer :", ch_answer)
    # sys.exit() instead of the bare exit(): the latter is injected by the
    # ``site`` module for interactive use and does not exist under ``python -S``.
    sys.exit(-1)


def get_values():
    """Return integer probe values clustered around 8/16/32/64-bit limits.

    The list starts with 0, 1 and -1. For every width it then adds, in order:
    ``2**bits`` and ``2**bits - 1``; the three values around ``2**(bits-1)``;
    and the three values around ``-(2**(bits-1))``. No range filtering is
    applied here, so some values exceed 64 bits.
    """
    probes = [0, 1, -1]
    for width in (8, 16, 32, 64):
        full = 2**width
        half = 2 ** (width - 1)
        probes.extend(
            [
                full,
                full - 1,
                half - 1,
                half,
                half + 1,
                -half - 1,
                -half,
                -half + 1,
            ]
        )
    return probes


def is_valid_integer(x):
    """Return True when *x* lies in the closed range [-(2**63), 2**64 - 1]."""
    lowest = -(2**63)
    highest = 2**64 - 1
    return lowest <= x <= highest


# When true, test_pair() and test_float_pair() repeat every comparison with the
# integer operand(s) wrapped in explicit to<Type>(...) conversions.
TEST_WITH_CASTING = True
# When true, main() also writes each checked query and expected answer to
# .sql / .reference files.
GENERATE_TEST_FILES = False

# Type name -> properties consulted by inside_range(). Entries are generated
# per width as the unsigned type followed by the signed one:
# UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt64, Int64.
# Float32 / Float64 (bits 32 / 64, "sign": True, "float": True) are not
# registered.
TYPES = {
    "{}Int{}".format(prefix, bits): {"bits": bits, "sign": signed, "float": False}
    for bits in (8, 16, 32, 64)
    for prefix, signed in (("U", False), ("", True))
}


def inside_range(value, type_name):
    """Tell whether *value* is representable in the type named *type_name*.

    The type's properties are looked up in ``TYPES``. Types flagged as
    ``"float"`` accept every value; integer types accept values within their
    signed or unsigned range for the configured bit width.

    Raises:
        KeyError: if *type_name* is not a key of ``TYPES``.
    """
    spec = TYPES[type_name]
    width, signed, is_float = spec["bits"], spec["sign"], spec["float"]

    if is_float:
        return True

    if signed:
        lowest, highest = -(2 ** (width - 1)), 2 ** (width - 1) - 1
    else:
        lowest, highest = 0, 2**width - 1
    return lowest <= value <= highest


def test_operators(v1, v2, v1_passed, v2_passed):
    query_str = "{v1} = {v2}, {v1} != {v2}, {v1} < {v2}, {v1} <= {v2}, {v1} > {v2}, {v1} >= {v2},\t".format(
        v1=v1_passed, v2=v2_passed
    )
    query_str += "{v1} = {v2}, {v1} != {v2}, {v1} < {v2}, {v1} <= {v2}, {v1} > {v2}, {v1} >= {v2} ".format(
        v1=v2_passed, v2=v1_passed
    )

    answers = [v1 == v2, v1 != v2, v1 < v2, v1 <= v2, v1 > v2, v1 >= v2]
    answers += [v2 == v1, v2 != v1, v2 < v1, v2 <= v1, v2 > v1, v2 >= v1]

    answers_str = "\t".join([str(int(x)) for x in answers])

    return (query_str, answers_str)


# Probe values from get_values() that pass is_valid_integer(), in their original order.
VALUES = list(filter(is_valid_integer, get_values()))


def test_pair(v1, v2):
    """Build, run and verify the comparison query for an integer pair.

    The query selects both values followed by the plain comparisons from
    ``test_operators``. With ``TEST_WITH_CASTING`` enabled, the comparisons are
    repeated for every ``to<T1>(v1)`` / ``to<T2>(v2)`` combination in which each
    value fits its type according to ``inside_range``.

    Returns:
        ``(query, answers)`` after ``check_answers`` has accepted them.
    """
    plain_query, plain_answer = test_operators(v1, v2, str(v1), str(v2))
    query_parts = [plain_query]
    answer_parts = [plain_answer]

    if TEST_WITH_CASTING:
        left_types = [name for name in TYPES if inside_range(v1, name)]
        right_types = [name for name in TYPES if inside_range(v2, name)]
        for left_type in left_types:
            left_expr = "to{}({})".format(left_type, v1)
            for right_type in right_types:
                right_expr = "to{}({})".format(right_type, v2)
                cast_query, cast_answer = test_operators(v1, v2, left_expr, right_expr)
                query_parts.append(cast_query)
                answer_parts.append(cast_answer)

    # Query fragments are comma-separated; expected answers are tab-separated.
    query = "SELECT {}, {}, ".format(v1, v2) + ", ".join(query_parts)
    answers = "{}\t{}\t".format(v1, v2) + "\t".join(answer_parts)

    check_answers(query, answers)
    return query, answers


# Integer operands for the integer-vs-float comparisons, grouped by theme.
# Element order is significant: it fixes the order of the generated queries.
VALUES_INT = (
    [0, -1, 1]
    # 64-bit limits.
    + [2**64 - 1, 2**63, -(2**63), 2**63 - 1]
    + [2**51, 2**52]
    # Neighbourhood of +/-2**53 (presumably chosen because a 64-bit float has a
    # 53-bit significand -- confirm).
    + [2**53 + delta for delta in (-1, 0, 1, 2)]
    + [-(2**53) - delta for delta in (-1, 0, 1, 2)]
    # NOTE(review): 2 * 52 evaluates to 104; it may have been meant as 2**52,
    # which is already listed above -- confirm before changing.
    + [2 * 52, -(2**52)]
)

# Float operands: every integer above converted to float, followed by
# fractional values and magnitudes beyond the signed 64-bit range.
VALUES_FLOAT = list(
    map(
        float,
        VALUES_INT
        + [
            -0.5,
            0.5,
            -1.5,
            1.5,
            2**53,
            2**51 - 0.5,
            2**51 + 0.5,
            2**60,
            -(2**60),
            -(2**63) - 10000,
            2**63 + 10000,
        ],
    )
)


def test_float_pair(i, f):
    """Build, run and verify the comparison query for an integer and a float.

    The float is rendered with nine decimal places (``"%.9f"``). The query
    first selects both operands as quoted strings, then the plain comparisons.
    With ``TEST_WITH_CASTING`` enabled, the comparisons are repeated with the
    integer wrapped in ``to<T>(...)`` for every type it fits according to
    ``inside_range``.

    Returns:
        ``(query, answers)`` after ``check_answers`` has accepted them.
    """
    f_str = "%.9f" % f

    plain_query, plain_answer = test_operators(i, f, i, f_str)
    query_parts = [plain_query]
    answer_parts = [plain_answer]

    if TEST_WITH_CASTING:
        for type_name in TYPES:
            if not inside_range(i, type_name):
                continue
            cast_query, cast_answer = test_operators(
                i, f, "to{}({})".format(type_name, i), f_str
            )
            query_parts.append(cast_query)
            answer_parts.append(cast_answer)

    # Query fragments are comma-separated; expected answers are tab-separated.
    query = "SELECT '{}', '{}', ".format(i, f_str) + ", ".join(query_parts)
    answers = "{}\t{}\t".format(i, f_str) + "\t".join(answer_parts)

    check_answers(query, answers)
    return query, answers


def main():
    """Run the test groups selected on the command line.

    Recognised arguments (``sys.argv[1:]``):

    * ``int1`` .. ``int4`` -- each runs one quarter of the integer pairs
      produced by ``itertools.combinations(VALUES, 2)``;
    * ``float`` -- runs every ``VALUES_INT`` x ``VALUES_FLOAT`` pair.

    When ``GENERATE_TEST_FILES`` is true, each checked query and its expected
    answer are also written to ``00411_accurate_number_comparison.sql`` and
    ``00411_accurate_number_comparison.reference`` in the current directory.

    ``PASSED`` is printed once all selected groups have finished; a failing
    check terminates the process inside ``check_answers`` before that.
    """
    selected = sys.argv[1:]
    sql_file = ref_file = None

    def record(query, answer):
        # Mirror a checked query/answer into the generated files, if enabled.
        if GENERATE_TEST_FILES:
            sql_file.write(query + ";\n")
            ref_file.write(answer + "\n")

    try:
        if GENERATE_TEST_FILES:
            base_name = "00411_accurate_number_comparison"
            sql_file = open(base_name + ".sql", "wt")
            ref_file = open(base_name + ".reference", "wt")

        # Number of unordered pairs, C(n, 2), computed without materialising them.
        num_int_tests = len(VALUES) * (len(VALUES) - 1) // 2

        num_parts = 4
        for part in range(num_parts):
            if "int" + str(part + 1) not in selected:
                continue
            # Slice the part-th quarter out of a fresh combinations iterator.
            chunk = itertools.islice(
                itertools.combinations(VALUES, 2),
                part * num_int_tests // num_parts,
                (part + 1) * num_int_tests // num_parts,
            )
            for v1, v2 in chunk:
                record(*test_pair(v1, v2))

        if "float" in selected:
            for i, f in itertools.product(VALUES_INT, VALUES_FLOAT):
                record(*test_float_pair(i, f))
    finally:
        # Close (and thereby flush) the generated files even if a check aborts
        # the run; previously they were left open until interpreter shutdown.
        for handle in (sql_file, ref_file):
            if handle is not None:
                handle.close()

    print("PASSED")


# Script entry point: run the selected test groups only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
